package com.bigfans.framework.dao.mybatis;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Pattern;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.MappedStatement.Builder;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import com.bigfans.framework.model.PageContext;
import com.bigfans.framework.utils.StringHelper;

@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
/**
 * MyBatis interceptor that rewrites eligible list queries into paged queries.
 *
 * <p>A query is paged when its mapped-statement id matches {@link #DEFAULT_PAGE_SQL_ID}, its
 * parameter object is a {@link Map}, and that map contains a {@code "pagesize"} entry. The
 * optional {@code "start"} entry defaults to 0. The SQL is rewritten by the configured
 * {@link Dialect}. When the map also holds {@code "pageable"} equal to {@code "true"}, a count
 * query is run on the current transaction's connection, and its result is published through
 * {@code PageContext.setDataCount}.
 */
public class PagerInterceptor implements Interceptor {

    /** Regular expression selecting the mapped-statement ids that are eligible for paging. */
    private static final String DEFAULT_PAGE_SQL_ID = ".*list$";

    /** Pre-compiled form of {@link #DEFAULT_PAGE_SQL_ID}, so the regex is not recompiled per query. */
    private static final Pattern PAGE_SQL_ID_PATTERN = Pattern.compile(DEFAULT_PAGE_SQL_ID);

    /**
     * SQL dialect used to build the paged and count statements.
     *
     * <p>Defaults to MySQL, which is the same default {@link #setProperties(Properties)} picks when
     * no {@code dialect} is configured. This avoids an NPE if {@code setProperties} is never called.
     */
    private Dialect dialect = new MySqlDialect();

    /**
     * Pages the intercepted {@code Executor.query} call when it is eligible; otherwise proceeds
     * unchanged.
     *
     * @param invocation the intercepted 4- or 6-argument {@code Executor.query} call
     * @return the query result (the paged result when paging applies)
     * @throws Throwable anything thrown by the underlying executor, or an {@link SQLException}
     *                   from the count query
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameterObject = args[1];

        // Only Map parameters can carry paging info. instanceof is false for null, and a
        // POJO/scalar parameter used to fail with ClassCastException here.
        if (!(parameterObject instanceof Map)
                || !PAGE_SQL_ID_PATTERN.matcher(mappedStatement.getId()).matches()) {
            return invocation.proceed();
        }

        Map<?, ?> params = (Map<?, ?>) parameterObject;
        Object start = params.get("start");
        Object pagesize = params.get("pagesize");
        // Without a page size there is nothing to limit. Previously a lone "start" (or a lone
        // "pagesize") caused a NullPointerException below.
        if (pagesize == null) {
            return invocation.proceed();
        }

        Executor executor = (Executor) invocation.getTarget();
        RowBounds rowBounds = (RowBounds) args[2];
        ResultHandler resultHandler = (ResultHandler) args[3];
        // The 4-argument overload has no BoundSql yet; the 6-argument overload supplies one.
        BoundSql boundSql = (args.length == 4)
                ? mappedStatement.getBoundSql(parameterObject)
                : (BoundSql) args[5];

        // Rewrite the SQL into its paged form.
        String originalSql = boundSql.getSql().trim();
        Long offset = (start == null) ? Long.valueOf(0L) : Long.valueOf(start.toString());
        Long limit = Long.valueOf(pagesize.toString());
        String pageSql = dialect.getPagerSql(originalSql, offset, limit);
        BoundSql pageBoundSql = copyFromBoundSql(mappedStatement, boundSql, pageSql);

        // Rebuild the cache key from the *paged* BoundSql. The parameter mappings are reused
        // unchanged, so start/pagesize presumably reach the database only through the SQL text.
        // A key derived from the un-paged SQL (or the caller-supplied args[4]) would therefore be
        // the same for every page, and the session cache could return a previously fetched page.
        CacheKey cacheKey = executor.createCacheKey(mappedStatement, parameterObject, rowBounds, pageBoundSql);
        Object result = executor.query(mappedStatement, parameterObject, rowBounds, resultHandler, cacheKey,
                pageBoundSql);

        // Only compute the total row count when the caller explicitly asked for it.
        boolean pageable = Boolean.parseBoolean(String.valueOf(params.get("pageable")));
        if (pageable && result != null) {
            Number count = getRowCount(executor.getTransaction().getConnection(), originalSql, mappedStatement,
                    boundSql);
            PageContext.setDataCount(count.longValue());
        }
        return result;
    }

    /**
     * Runs the dialect's count query for {@code originalSql}, binding the same parameters as the
     * original statement.
     *
     * <p>The statement and result set are always closed. The connection belongs to the executor's
     * transaction and is deliberately left open.
     *
     * @param conn            connection of the current transaction (not closed here)
     * @param originalSql     the un-paged SQL text
     * @param mappedStatement the statement being paged
     * @param boundSql        the un-paged bound SQL, whose parameters (including additional
     *                        parameters) are re-bound to the count query
     * @return the first column of the first row, or 0 if the count query returned no rows
     * @throws SQLException if preparing or executing the count query fails
     */
    private Number getRowCount(Connection conn, String originalSql, MappedStatement mappedStatement,
                               BoundSql boundSql) throws SQLException {
        String countSql = dialect.getCountSql(originalSql);
        // copyFromBoundSql also carries over additional parameters (e.g. <foreach> bindings),
        // which a bare new BoundSql(...) would lose.
        BoundSql countBoundSql = copyFromBoundSql(mappedStatement, boundSql, countSql);
        ParameterHandler parameterHandler =
                new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql);
        try (PreparedStatement pstmt = conn.prepareStatement(countSql)) {
            parameterHandler.setParameters(pstmt);
            try (ResultSet rs = pstmt.executeQuery()) {
                return rs.next() ? rs.getLong(1) : 0L;
            }
        }
    }

    /**
     * Creates a copy of {@code ms} that uses {@code newSqlSource}, copying selected builder settings.
     *
     * <p>NOTE(review): this private method is not called anywhere in this class.
     *
     * @param ms           the statement to copy
     * @param newSqlSource the SQL source for the copy
     * @return the new mapped statement
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        Builder builder = new Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());

        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());

        // getKeyProperties() may be null when the statement declares no key properties;
        // iterating it unguarded used to throw a NullPointerException.
        String[] keyProperties = ms.getKeyProperties();
        if (keyProperties != null && keyProperties.length > 0) {
            StringBuilder kps = new StringBuilder();
            for (String kp : keyProperties) {
                if (kps.length() > 0) {
                    kps.append(',');
                }
                kps.append(kp);
            }
            builder.keyProperty(kps.toString());
        }
        return builder.build();
    }

    /**
     * Copies {@code boundSql} with new SQL text.
     *
     * <p>The copy keeps the same parameter mappings and parameter object. Any additional
     * parameter that is referenced by a mapping is copied as well.
     *
     * @param ms       statement whose configuration the new BoundSql uses
     * @param boundSql the source bound SQL
     * @param sql      the replacement SQL text
     * @return the new bound SQL
     */
    private BoundSql copyFromBoundSql(MappedStatement ms, BoundSql boundSql, String sql) {
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), sql, boundSql.getParameterMappings(),
                boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        return newBoundSql;
    }

    /**
     * Wraps {@code target} in a MyBatis plugin proxy driven by this interceptor.
     *
     * @param target the object MyBatis offers for interception
     * @return the proxied (or original) object returned by {@link Plugin#wrap}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Selects the SQL dialect from the {@code dialect} property.
     *
     * <p>An empty or missing value, or {@code mysql} in any letter case, selects
     * {@link MySqlDialect}.
     *
     * @param properties plugin properties (may be null)
     * @throws IllegalArgumentException if another dialect is requested
     */
    @Override
    public void setProperties(Properties properties) {
        String dialectType = (properties == null) ? null : properties.getProperty("dialect");
        if (StringHelper.isEmpty(dialectType) || "mysql".equalsIgnoreCase(dialectType)) {
            dialect = new MySqlDialect();
            return;
        }
        throw new IllegalArgumentException("mybatis plugin : unsupport dialect: " + dialectType);
    }

}
